%matplotlib inline
import pandas as pd
import os
import glob
import pickle
import phate
import scprep
import meld
import graphtools as gt
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import time
import datetime
import scanpy as sc
from sklearn.decomposition import PCA
from py_pcha import PCHA
from mpl_toolkits.mplot3d import Axes3D
import matplotlib as mpl
# settings
# Global plotting defaults for the notebook. Statement order is kept as in the
# original: the scanpy and seaborn calls presumably overwrite some rcParams
# themselves, so reordering could change the final values (confirm before
# moving anything).
plt.rc('font', size=9)
plt.rc('font', family='sans serif')
# Font type 42 embeds TrueType fonts in PDF/PS output.
plt.rcParams['pdf.fonttype'] = 42
plt.rcParams['ps.fonttype'] = 42
plt.rcParams['text.usetex'] = False
plt.rcParams['legend.frameon'] = False
plt.rcParams['axes.grid'] = False
plt.rcParams['legend.markerscale'] = 0.5
sc.set_figure_params(
    dpi=300,
    dpi_save=600,
    frameon=False,
    fontsize=9,
)
plt.rcParams['savefig.dpi'] = 600
sc.settings.verbosity = 2
# Use the public settings object. Assigning to the private
# `sc._settings.ScanpyConfig` class replaces its `n_jobs` attribute for every
# instance and bypasses whatever the setter does.
sc.settings.n_jobs = -1
sns.set_style("ticks")
# Load the 'gdata' CSV tables from the results directory.
pfp = '/home/ngr4/project/sccovid/results/'

# Keep the CSV files whose *file name* contains 'gdata'. Testing the base name
# (not the full path) keeps the filter correct if `pfp` itself ever contains
# that substring.
dremi_data_files = [
    csv_path
    for csv_path in glob.glob(os.path.join(pfp, '*.csv'))
    if 'gdata' in os.path.basename(csv_path)
]
# Notebook display of the matched files. This used to sit above the
# assignment, where it raised a NameError on a fresh run.
dremi_data_files

# Map "file name without the .csv suffix" -> DataFrame (first column as index).
data = {
    os.path.split(csv_path)[1].split('.csv')[0]: pd.read_csv(csv_path, index_col=0)
    for csv_path in dremi_data_files
}

# Both tables are re-indexed by their 'gene' column.
dremis_sep = data['gdata'].set_index('gene')
dremis = data['gdata_horizconcat'].set_index('gene')
del data  # the dict is no longer needed once the two tables are extracted
dremis
# Archetypal analysis (PCHA) with 8 archetypes, visualised in PCA and PHATE
# space.
calc_pcaphate = False  # True forces PCA and PHATE to be re-fitted
n_AT = 8  # number of archetypes, passed to PCHA as `noc`
pcha_on_pca = False  # True runs PCHA on the PCA scores instead of the data

# Fit PCA when requested, and also when no fitted PCA is available yet.
# Without the second condition a fresh run with calc_pcaphate=False fails
# with a NameError at the first use of `pca` / `dremis_pca`.
if calc_pcaphate or 'pca' not in globals() or 'dremis_pca' not in globals():
    pca = PCA(n_components=100).fit(dremis)
    dremis_pca = pca.transform(dremis)

# try PCHA
X = np.array(dremis_pca) if pcha_on_pca else np.array(dremis)
start = time.time()
XC, S, C, SSE, varexpl = PCHA(X.T, noc=n_AT)  # S for each cell sum to 1
print('AA on sample in {:.2f}-min'.format((time.time()-start)/60))

# Same rule for the PHATE operator and embedding (plot on PHATE space).
if (calc_pcaphate
        or 'phate_op' not in globals()
        or 'dremis_phate' not in globals()):
    phate_op = phate.PHATE(n_components=3).fit(dremis)
    dremis_phate = phate_op.transform(dremis)

# transform ATs
# One archetype per row. np.asarray is a no-op for ndarrays and guards against
# PCHA handing back an np.matrix (TODO confirm), which downstream `transform`
# calls may reject.
archetypes = np.asarray(XC.T)
if pcha_on_pca:
    # PCHA on PCA data: archetypes are already PCA coordinates
    Y_pca = archetypes
else:
    # PCHA in data-space: project archetypes with the fitted PCA
    Y_pca = pca.transform(archetypes)
# NOTE(review): this assumes the archetypes live in data space; with
# pcha_on_pca=True they would presumably need mapping back first -- confirm.
Y_phate = phate_op.transform(archetypes)

# Shared styling for the archetype number labels.
at_font = {'color': 'white', 'size': 10, 'weight': 'bold'}

# Side-by-side overview of both embeddings.
fig, ax = plt.subplots(1, 2, figsize=(6, 2))
scprep.plot.scatter2d(dremis_pca,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PCA',
                      ax=ax[0])
scprep.plot.scatter2d(dremis_phate,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PHATE',
                      ax=ax[1])

# Cumulative explained variance. A new figure is opened explicitly so that,
# when everything runs as one script, the curve is not drawn onto the PHATE
# panel above (seaborn plots onto the current axes).
plt.figure()
p = sns.scatterplot(x=list(range(pca.explained_variance_ratio_.shape[0])),
                    y=np.cumsum(pca.explained_variance_ratio_))
p.set_ylabel('Variance explained')
p.set_xlabel('PC')
p.set_title('Dimensionality of data')

# plot on PCA space
fig = plt.figure(figsize=(3, 2))
plt.scatter(dremis_pca[:, 0], dremis_pca[:, 1], s=3, alpha=0.5, c='#f7c09c')
plt.scatter(Y_pca[:, 0], Y_pca[:, 1], s=200, c='#616066')
plt.xticks([])
plt.yticks([])
for at_number, at_xy in enumerate(Y_pca, start=1):
    plt.text(at_xy[0], at_xy[1], str(at_number),
             horizontalalignment='center', verticalalignment='center',
             fontdict=at_font)

# plot on phate space
fig = plt.figure(figsize=(3, 2))
plt.scatter(dremis_phate[:, 0], dremis_phate[:, 1],
            s=3, alpha=0.5, c='#f7c09c', lw=0)
plt.scatter(Y_phate[:, 0], Y_phate[:, 1], s=200, c='#616066')
plt.xticks([])
plt.yticks([])
for at_number, at_xy in enumerate(Y_phate, start=1):
    plt.text(at_xy[0], at_xy[1], str(at_number),
             horizontalalignment='center', verticalalignment='center',
             fontdict=at_font)

# 3d plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dremis_phate[:, 0], dremis_phate[:, 1], dremis_phate[:, 2],
           c='skyblue', s=1, alpha=0.2)
ax.scatter(Y_phate[:, 0], Y_phate[:, 1], Y_phate[:, 2], s=200, c='#616066')
for at_number, at_xyz in enumerate(Y_phate, start=1):
    ax.text(at_xyz[0], at_xyz[1], at_xyz[2], str(at_number),
            horizontalalignment='center', verticalalignment='center',
            fontdict=at_font)
# ax.view_init(30, 185)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')
# try different ats
# Archetypal analysis (PCHA) with 6 archetypes, visualised in PHATE space.
calc_pcaphate = False  # True forces PCA and PHATE to be re-fitted
n_AT = 6  # number of archetypes, passed to PCHA as `noc`
pcha_on_pca = False  # True runs PCHA on the PCA scores instead of the data

# Fit PCA when requested, and also when no fitted PCA is available yet.
# Without the second condition a fresh run with calc_pcaphate=False fails
# with a NameError at the first use of `pca` / `dremis_pca`.
if calc_pcaphate or 'pca' not in globals() or 'dremis_pca' not in globals():
    pca = PCA(n_components=100).fit(dremis)
    dremis_pca = pca.transform(dremis)

# try PCHA
X = np.array(dremis_pca) if pcha_on_pca else np.array(dremis)
start = time.time()
XC, S, C, SSE, varexpl = PCHA(X.T, noc=n_AT)  # S for each cell sum to 1
print('AA on sample in {:.2f}-min'.format((time.time()-start)/60))

# Same rule for the PHATE operator and embedding (plot on PHATE space).
if (calc_pcaphate
        or 'phate_op' not in globals()
        or 'dremis_phate' not in globals()):
    phate_op = phate.PHATE(n_components=3).fit(dremis)
    dremis_phate = phate_op.transform(dremis)

# transform ATs
# One archetype per row. np.asarray is a no-op for ndarrays and guards against
# PCHA handing back an np.matrix (TODO confirm), which downstream `transform`
# calls may reject.
archetypes = np.asarray(XC.T)
if pcha_on_pca:
    # PCHA on PCA data: archetypes are already PCA coordinates
    Y_pca = archetypes
else:
    # PCHA in data-space: project archetypes with the fitted PCA
    Y_pca = pca.transform(archetypes)
# NOTE(review): this assumes the archetypes live in data space; with
# pcha_on_pca=True they would presumably need mapping back first -- confirm.
Y_phate = phate_op.transform(archetypes)

# Shared styling for the archetype number labels.
at_font = {'color': 'white', 'size': 10, 'weight': 'bold'}

# Side-by-side overview of both embeddings.
fig, ax = plt.subplots(1, 2, figsize=(6, 2))
scprep.plot.scatter2d(dremis_pca,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PCA',
                      ax=ax[0])
scprep.plot.scatter2d(dremis_phate,
                      ticks=None,
                      c='#f7c09c',
                      label_prefix='PHATE',
                      ax=ax[1])

# plot on phate space
fig = plt.figure(figsize=(3, 2))
plt.scatter(dremis_phate[:, 0], dremis_phate[:, 1],
            s=3, alpha=0.5, c='#f7c09c', lw=0)
plt.scatter(Y_phate[:, 0], Y_phate[:, 1], s=200, c='#616066')
plt.xticks([])
plt.yticks([])
for at_number, at_xy in enumerate(Y_phate, start=1):
    plt.text(at_xy[0], at_xy[1], str(at_number),
             horizontalalignment='center', verticalalignment='center',
             fontdict=at_font)

# 3d plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dremis_phate[:, 0], dremis_phate[:, 1], dremis_phate[:, 2],
           c='skyblue', s=1, alpha=0.2)
ax.scatter(Y_phate[:, 0], Y_phate[:, 1], Y_phate[:, 2], s=200, c='#616066')
for at_number, at_xyz in enumerate(Y_phate, start=1):
    ax.text(at_xyz[0], at_xyz[1], at_xyz[2], str(at_number),
            horizontalalignment='center', verticalalignment='center',
            fontdict=at_font)
# ax.view_init(30, 185)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')
# try different ats
# Archetypal analysis (PCHA) with 12 archetypes, visualised in PHATE space.
calc_pcaphate = False  # True forces PCA and PHATE to be re-fitted
n_AT = 12  # number of archetypes, passed to PCHA as `noc`
pcha_on_pca = False  # True runs PCHA on the PCA scores instead of the data

# Fit PCA when requested, and also when no fitted PCA is available yet.
# Without the second condition a fresh run with calc_pcaphate=False fails
# with a NameError at the first use of `pca` / `dremis_pca`.
if calc_pcaphate or 'pca' not in globals() or 'dremis_pca' not in globals():
    pca = PCA(n_components=100).fit(dremis)
    dremis_pca = pca.transform(dremis)

# try PCHA
X = np.array(dremis_pca) if pcha_on_pca else np.array(dremis)
start = time.time()
XC, S, C, SSE, varexpl = PCHA(X.T, noc=n_AT)  # S for each cell sum to 1
print('AA on sample in {:.2f}-min'.format((time.time()-start)/60))

# Same rule for the PHATE operator and embedding (plot on PHATE space).
if (calc_pcaphate
        or 'phate_op' not in globals()
        or 'dremis_phate' not in globals()):
    phate_op = phate.PHATE(n_components=3).fit(dremis)
    dremis_phate = phate_op.transform(dremis)

# transform ATs
# One archetype per row. np.asarray is a no-op for ndarrays and guards against
# PCHA handing back an np.matrix (TODO confirm), which downstream `transform`
# calls may reject.
archetypes = np.asarray(XC.T)
if pcha_on_pca:
    # PCHA on PCA data: archetypes are already PCA coordinates
    Y_pca = archetypes
else:
    # PCHA in data-space: project archetypes with the fitted PCA
    Y_pca = pca.transform(archetypes)
# NOTE(review): this assumes the archetypes live in data space; with
# pcha_on_pca=True they would presumably need mapping back first -- confirm.
Y_phate = phate_op.transform(archetypes)

# Shared styling for the archetype number labels.
at_font = {'color': 'white', 'size': 10, 'weight': 'bold'}

# plot on phate space
fig = plt.figure(figsize=(3, 2))
plt.scatter(dremis_phate[:, 0], dremis_phate[:, 1],
            s=3, alpha=0.5, c='#f7c09c', lw=0)
plt.scatter(Y_phate[:, 0], Y_phate[:, 1], s=200, c='#616066')
plt.xticks([])
plt.yticks([])
for at_number, at_xy in enumerate(Y_phate, start=1):
    plt.text(at_xy[0], at_xy[1], str(at_number),
             horizontalalignment='center', verticalalignment='center',
             fontdict=at_font)

# 3d plot
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(dremis_phate[:, 0], dremis_phate[:, 1], dremis_phate[:, 2],
           c='skyblue', s=1, alpha=0.2)
ax.scatter(Y_phate[:, 0], Y_phate[:, 1], Y_phate[:, 2], s=200, c='#616066')
for at_number, at_xyz in enumerate(Y_phate, start=1):
    ax.text(at_xyz[0], at_xyz[1], at_xyz[2], str(at_number),
            horizontalalignment='center', verticalalignment='center',
            fontdict=at_font)
# ax.view_init(30, 185)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')
# scprep utils, REF: https://github.com/KrishnaswamyLab/scprep/blob/0d4843f35702d0d8778e7a6b843f8a4262048d60/scprep/plot/utils.py#L70
def _in_ipynb():
    """Report whether the code runs inside a recognised notebook shell.

    The type of the object returned by ``get_ipython()`` is compared, as a
    string, against the Google Colab shell and the ipykernel ZMQ shell.

    Credit to https://stackoverflow.com/a/24937408/3996580

    Returns
    -------
    bool
        True for one of the two recognised shells; False for any other shell
        or when ``get_ipython`` is not defined at all.
    """
    try:
        shell_type = str(type(get_ipython()))
    except NameError:
        # `get_ipython` does not exist in this interpreter session.
        return False
    return shell_type in (
        "<class 'google.colab._shell.Shell'>",
        "<class 'ipykernel.zmqshell.ZMQInteractiveShell'>",
    )
def _get_figure(ax=None, figsize=None, subplot_kw=None):
    """Return a figure/axes pair, creating one when no axes is supplied.

    Parameters
    ----------
    ax : matplotlib.axes.Axes, optional
        Existing axes to draw on. When None, a new figure and axes are
        created with ``plt.subplots``.
    figsize : tuple, optional
        Forwarded to ``plt.subplots`` when a new figure is created.
    subplot_kw : dict, optional
        Forwarded to ``plt.subplots``; a ``"projection": "3d"`` entry also
        requires any supplied `ax` to be an ``Axes3D``.

    Returns
    -------
    fig, ax, show_fig
        The figure, the axes, and a flag that is True only when the figure
        was created here.

    Raises
    ------
    TypeError
        If `ax` is not a matplotlib axes, or is not an ``Axes3D`` although a
        3d projection was requested.
    """
    kwargs = {} if subplot_kw is None else subplot_kw
    wants_3d = "projection" in kwargs and kwargs["projection"] == "3d"

    # No axes given: build a fresh figure, which the caller should show.
    if ax is None:
        if wants_3d:
            # ensure mplot3d is loaded
            Axes3D
        new_fig, new_ax = plt.subplots(figsize=figsize, subplot_kw=kwargs)
        return new_fig, new_ax, True

    # Axes given: recover its parent figure and validate the axes type.
    try:
        parent_fig = ax.get_figure()
    except AttributeError:
        if isinstance(ax, mpl.axes.Axes):
            # A genuine Axes failed for another reason; propagate unchanged.
            raise
        raise TypeError(
            "Expected ax as a matplotlib.axes.Axes. " "Got {}".format(type(ax))
        )
    if wants_3d and not isinstance(ax, Axes3D):
        raise TypeError(
            "Expected ax with projection='3d'. " "Got 2D axis instead."
        )
    return parent_fig, ax, False
def show(fig):
    """Show a matplotlib Figure correctly, regardless of platform

    If running a Jupyter notebook, we avoid running `fig.show`. If running
    in Windows, it is necessary to run `plt.show` rather than `fig.show`.

    Parameters
    ----------
    fig : matplotlib.Figure
        Figure to show
    """
    # Imported here because `platform` is not among the file's top-level
    # imports; the previous version raised a NameError on this name.
    import platform

    fig.tight_layout()

    # Inline replacement for the `_mpl_is_gui_backend` helper, which this
    # file never defined. Inline notebook backends and the plain "agg"
    # backend have no window to open, so there is nothing to show.
    non_gui_backends = (
        "module://ipykernel.pylab.backend_inline",
        "module://matplotlib_inline.backend_inline",
        "agg",
    )
    try:
        backend = mpl.get_backend()
    except Exception:
        # Backend cannot be determined: treat it as non-interactive.
        return
    if str(backend).lower() in non_gui_backends:
        return

    if platform.system() == "Windows":
        plt.show(block=True)
    else:
        fig.show()
# rotate 3d
# Rotating 3-D view of the PHATE embedding with the archetypes overlaid.
# `import matplotlib as mpl` does not guarantee that the `animation` submodule
# is loaded, so it is imported explicitly instead of relying on
# `mpl.animation` being present.
from matplotlib import animation as mpl_animation

filename = None  # output path ending in .gif or .mp4; None skips saving
dpi = 300  # resolution used when saving
rotation_speed = 30  # degrees per second
fps = 1  # frames per second
elev = None  # elevation passed to `view_init` on every frame
figsize = None
ipython_html = "jshtml"  # value for the `animation.html` rc setting
ax = None  # None lets `_get_figure` create a new 3-D axes

if _in_ipynb():
    # in ipynb
    # credit to
    # http://tiao.io/posts/notebooks/save-matplotlib-animations-as-gifs/
    mpl.rc("animation", html=ipython_html)

# Choose the writer from the file extension; only needed when saving.
if filename is not None:
    if filename.endswith(".gif"):
        writer = "imagemagick"
    elif filename.endswith(".mp4"):
        writer = "ffmpeg"
    else:
        raise ValueError(
            "filename must end in .gif or .mp4. Got {}".format(filename)
        )

degrees_per_frame = rotation_speed / fps
frames = int(round(360 / degrees_per_frame))
# fix rounding errors: make the frames cover exactly one full turn
degrees_per_frame = 360 / frames
interval = 1000 * degrees_per_frame / rotation_speed  # milliseconds per frame

fig, ax, show_fig = _get_figure(ax, figsize, subplot_kw={"projection": "3d"})
# ax = fig.add_subplot(111, projection='3d')
ax.scatter(dremis_phate[:, 0], dremis_phate[:, 1], dremis_phate[:, 2],
           c='skyblue', s=1, alpha=0.2)
ax.scatter(Y_phate[:, 0], Y_phate[:, 1], Y_phate[:, 2], s=200, c='#616066')
for at_number, at_xyz in enumerate(Y_phate, start=1):
    ax.text(at_xyz[0], at_xyz[1], at_xyz[2], str(at_number),
            horizontalalignment='center', verticalalignment='center',
            fontdict={'color': 'white', 'size': 10, 'weight': 'bold'})
# ax.view_init(30, 185)
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlabel('PHATE1')
ax.set_ylabel('PHATE2')
ax.set_zlabel('PHATE3')

azim = ax.azim  # starting azimuth; every frame is offset from it


def init():
    return ax


def animate(i):
    # Frame i turns the camera i * degrees_per_frame away from the start.
    ax.view_init(azim=azim + i * degrees_per_frame, elev=elev)
    return ax


ani = mpl_animation.FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=range(frames),
    interval=interval,
    blit=False,
)

if filename is not None:
    ani.save(filename, writer=writer, dpi=dpi)

if _in_ipynb():
    # credit to https://stackoverflow.com/a/45573903/3996580
    plt.close(fig)
elif show_fig:
    show(fig)

# Notebook display expressions kept from the original cells.
ani
ani
_in_ipynb()